# _*_ coding: utf-8 _*_
# @Date : 2023/3/24 17:10
# @Author : Paul
# @File : XGB_regressor.py
# @Description :
import matplotlib.pyplot as plt
import numpy as np
import sys
import json

from core.beans.param_train_result import ParamTrainResult

from core.utils.log_util import LogUtil
from core.beans.regress_result import RegressionResult
from core.utils.string_utils import StringUtils
from core.utils.date_util import DateUtil
from regressions.regression import Regression
from xgboost import XGBRegressor as XGBR


class TWXGBRegressor(Regression):
    """XGBoost regression task built on the project's ``Regression`` base class.

    The class configures an ``xgboost.XGBRegressor`` from ``self.param``,
    fits it, scores it with RMSE on the test split, and stores the outcome
    through ``LogUtil``.  ``paramTrain`` additionally sweeps one
    hyper-parameter over an integer range and saves the learning curve.

    Attributes such as ``self.param``, ``self.start_time``, ``self.info``,
    ``self.describe``, ``self.image_path``, ``self.regress_pred_image`` and
    ``self.meta_data_source`` are read here but not assigned in this class;
    presumably the ``Regression`` base class provides them — verify there.
    """

    # Plot titles for the hyper-parameters that paramTrain can sweep.  The
    # keys double as the set of supported ``paramName`` values and are used
    # verbatim as XGBRegressor keyword arguments.
    _PARAM_TRAIN_TITLES = {
        "max_depth": "最大数据深度--RMSE--超参数学习曲线",
        "n_estimators": "基评估器的数量--RMSE--超参数学习曲线",
    }

    def __init__(self,
                 app_name="linear_regression",
                 data_source_id=None,
                 table_name=None,
                 feature_cols=None,
                 class_col=None,
                 train_size=None,
                 param=None
                 ):
        """
        Initialise the task; every argument is forwarded unchanged to
        ``Regression.__init__``.

        :param app_name: application name, also embedded in the plot name
        :param data_source_id: forwarded to the base class
        :param table_name: forwarded to the base class
        :param feature_cols: forwarded to the base class
        :param class_col: forwarded to the base class
        :param train_size: forwarded to the base class
        :param param: task parameter dict (read as ``self.param`` below)
        """
        super(TWXGBRegressor, self).__init__(app_name=app_name,
                                             data_source_id=data_source_id,
                                             table_name=table_name,
                                             feature_cols=feature_cols,
                                             class_col=class_col,
                                             train_size=train_size,
                                             param=param)
        # NOTE(review): the original comment read "default: model evaluation
        # not needed", which contradicts the value True — confirm how the
        # base class interprets this flag.
        self.IS_MODEL_EVAL = True
        # Name for the prediction plot.  It is not referenced anywhere else
        # in this class (evalModel saves to self.regress_pred_image).
        self.tree_pred_image = "descion_tree_pred_" + app_name + "_" + DateUtil.getCurrentDateSimple()
        # Maximum number of test rows drawn in the prediction plot.
        self.plot_rows_num = 200

    @staticmethod
    def _toIntOrNone(value):
        """Return ``int(value)``, or None when the value is blank or not an integer."""
        if StringUtils.isBlack(value):
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            # Non-numeric input is reported through the normal "failed"
            # result path instead of crashing the job.
            return None

    def _saveParamTrainResult(self, info, status):
        """Persist one hyper-parameter training outcome via ``LogUtil``.

        :param info: error text on failure, saved image path on success
        :param status: status string stored with the result
        """
        end_time = DateUtil.getCurrentDate()
        cost_second = DateUtil.diffMin(self.start_time, end_time)
        param_train_result = ParamTrainResult(self.param["id"],
                                              "XGB_regressor",
                                              self.param,
                                              self.app_name,
                                              info,
                                              status,
                                              self.start_time,
                                              end_time,
                                              cost_second)
        LogUtil.saveParamTrainResult(self.meta_data_source, param_train_result)

    def initModel(self):
        """
        Create the XGBoost regressor from ``self.param["algoParam"]``.

        ``nEstimators`` falls back to 100 and ``randomState`` to 0 when
        ``StringUtils.isBlack`` reports the configured value as blank.
        """
        algoParam = self.param["algoParam"]
        n_estimators = 100 if StringUtils.isBlack(algoParam["nEstimators"]) else int(algoParam["nEstimators"])
        random_state = 0 if StringUtils.isBlack(algoParam["randomState"]) else int(algoParam["randomState"])

        self.reg = XGBR(n_estimators=n_estimators,
                        random_state=random_state)

    def buildModel(self, train_data):
        """
        Fit the regressor.

        :param train_data: indexable pair ``(X_train, y_train)``
        """
        Xtrain = train_data[0]
        Ytrain = train_data[1]
        self.reg = self.reg.fit(Xtrain, Ytrain)

    def evalModel(self, train_data, test_data):
        """
        Score the fitted model on the test split, plot predictions against
        the true values, and persist the result.

        :param train_data: unused here; kept for the base-class interface
        :param test_data: indexable pair ``(X_test, y_test)``
        """
        from sklearn.metrics import mean_squared_error as MSE

        Xtest = test_data[0]
        Ytest = test_data[1]
        Ytest_predict = self.reg.predict(Xtest)
        # Root mean squared error on the test split.
        self.score_ = np.sqrt(MSE(Ytest, Ytest_predict))

        var_importance = [*zip(self.feature_cols, self.reg.feature_importances_)]

        test_data_rows = len(Xtest)

        # CJK-capable fonts: prefer KaiTi, fall back to SimHei.  The original
        # listed 'Kaitt', which is not a font family name, so the preferred
        # font could never be selected.
        plt.rcParams['font.sans-serif'] = ['KaiTi', 'SimHei']

        # Render the minus sign correctly with the fonts above.
        plt.rcParams['axes.unicode_minus'] = False
        plt.figure()
        if test_data_rows < self.plot_rows_num:
            plt.plot(np.linspace(0.05, 1, test_data_rows), Ytest, "green", label="Y-真实值")
            plt.plot(np.linspace(0.05, 1, test_data_rows), Ytest_predict, "red", label="Y-预测值")
        else:
            # Only the first plot_rows_num rows are drawn.
            plt.plot(np.linspace(0.05, 1, self.plot_rows_num), Ytest[:self.plot_rows_num], "green", label="Y-真实值")
            plt.plot(np.linspace(0.05, 1, self.plot_rows_num), Ytest_predict[:self.plot_rows_num], "red", label="Y-预测值")

        plt.legend()
        plt.savefig(self.regress_pred_image, dpi=300)
        # NOTE(review): plt.show() can block under an interactive backend and
        # paramTrain does not call it — confirm it is wanted for batch runs.
        plt.show()

        # End time and elapsed time as computed by DateUtil.diffMin.
        end_time = DateUtil.getCurrentDate()
        cost_second = DateUtil.diffMin(self.start_time, end_time)

        # Persist the evaluation result.
        # NOTE(review): the status literal "sucess" is misspelled (paramTrain
        # stores "success"); left untouched because stored values may be
        # matched elsewhere — confirm before correcting.
        algo_result = RegressionResult(self.param["id"],
                                       "XGB_regressor",
                                       self.param,
                                       self.app_name,
                                       self.info,
                                       self.describe,
                                       self.regress_pred_image,
                                       var_importance,
                                       self.score_,
                                       "sucess",
                                       self.start_time,
                                       end_time,
                                       cost_second)
        LogUtil.saveRegressionResult(self.meta_data_source, algo_result)

    def paramTrain(self):
        """
        Sweep one hyper-parameter and record the RMSE learning curve.

        Reads ``paramName``, ``paramStartValue``, ``paramEndValue`` and
        ``paramRangeValue`` from ``self.param["paramTrain"]``, fits one model
        per value in ``range(start, end, step)``, plots RMSE on the test
        split against the value, saves the figure, and stores a
        ``ParamTrainResult``.  Invalid input is stored as a "failed" result
        instead of raising.

        :return: None
        """
        # Hyper-parameter sweep settings.
        param_train = self.param["paramTrain"]
        param_name = None if StringUtils.isBlack(param_train["paramName"]) else str(param_train["paramName"])
        param_start_value = self._toIntOrNone(param_train["paramStartValue"])
        param_end_value = self._toIntOrNone(param_train["paramEndValue"])
        param_range_value = self._toIntOrNone(param_train["paramRangeValue"])

        # Data is loaded before validation, exactly as before, so any side
        # effects of getModelData still happen on the failure path.
        train_data, test_data = self.getModelData()
        Xtrain = train_data[0]
        Ytrain = train_data[1]
        Xtest = test_data[0]
        Ytest = test_data[1]

        if (param_name is None
                or param_start_value is None
                or param_end_value is None
                or param_range_value is None
                or param_start_value == 1):
            self._saveParamTrainResult("请确认参数必须为整数，且参数起始值不能为1", "failed")
            return

        if param_name not in self._PARAM_TRAIN_TITLES:
            # Previously an unknown name built no model at all and predict()
            # ran against a missing or stale self.reg.
            self._saveParamTrainResult(
                "不支持的超参数名称：" + param_name + "，仅支持 max_depth、n_estimators",
                "failed")
            return

        from sklearn.metrics import mean_squared_error as MSE

        param_values = list(range(param_start_value, param_end_value, param_range_value))
        eval_value_list = []
        for param_value in param_values:
            # Fit for every supported parameter.  The original only fitted in
            # the n_estimators branch, so a max_depth sweep called predict()
            # on an unfitted model.
            self.reg = XGBR(**{param_name: param_value}).fit(Xtrain, Ytrain)
            Ytest_predict = self.reg.predict(Xtest)
            eval_value_list.append(np.sqrt(MSE(Ytest, Ytest_predict)))

        # Save the learning curve.
        param_train_image = self.image_path + "descion_tree_param_train_" + DateUtil.getCurrentDateSimple() + ".png"
        fig, ax = plt.subplots(1, 1)
        ax.set_title(self._PARAM_TRAIN_TITLES[param_name])
        ax.plot(param_values, eval_value_list)
        plt.savefig(param_train_image, dpi=300)
        # The figure is never shown, so release it once it is on disk.
        plt.close(fig)

        self._saveParamTrainResult(param_train_image, "success")


if __name__ == '__main__':
    # The single command-line argument is the JSON task description, e.g.
    # {"algoParam":{"nEstimators":"100","randomState":"0"},"appName":"kmeans_1",
    #  "classCols":"housePrice","dataSourceId":"9",
    #  "featureCols":"MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude",
    #  "id":"1679540469249","preProcessMethodList":[{"preProcessMethod":"deletena"}],
    #  "tableName":"california_housing","trainDataRatio":"0.7"}
    param = json.loads(sys.argv[1])

    # Keys are read in a fixed order so a missing key fails on the same
    # field as before.
    task_name = param["appName"]
    source_id = param["dataSourceId"]
    source_table = param["tableName"]
    features = param["featureCols"]
    targets = param["classCols"]
    train_ratio = float(param["trainDataRatio"])

    # The target column may arrive as a list or as a single value; the
    # constructor always receives a list.
    target_list = targets if isinstance(targets, list) else [targets]

    regressor = TWXGBRegressor(app_name=task_name,
                               data_source_id=source_id,
                               table_name=source_table,
                               feature_cols=features,
                               class_col=target_list,
                               train_size=train_ratio,
                               param=param)

    # A "paramTrain" section selects the hyper-parameter sweep; otherwise
    # the regular execute() entry point runs.
    if "paramTrain" in param:
        regressor.paramTrain()
    else:
        regressor.execute()